import torch
from einops import rearrange

# rearrange: 用于对张量的维度进行重新变换排序，可用于替换pytorch中的reshape，view，transpose和permute等操作

# 维度合并
x = torch.randn((3, 4, 5))
y = rearrange(x, "i j k ->(i j) k")  # torch.Size([12, 5])

# 维度拆分
x = torch.randn(4, 512)
y = rearrange(x, "i (j k)->i j k", j=64)  # torch.Size([4, 64, 8])
x = torch.randn(4, 8)
y = rearrange(x, "i (j k)-> i j k", j=1)

# 维度交换
x = torch.randn(3, 4, 5)
y = rearrange(x, "i j k -> j i k")

# 维度重组
x = torch.randn(3, 512, 5)
y = rearrange(x, "i (n m) k ->i n (k m)", n=64)  # torch.Size([3, 64, 40])

pass
